LstmGradWeight

计算 LSTM 网络的权重梯度,包括输入权重 W、隐藏权重 U 和偏置 b 的反向传播梯度。

\[\begin{split}dH_t &= dY_t + dH_{t+1} \\ dC_t &= dC_{t+1} \odot f_{t+1} + dH_t \odot o_t \odot (1 - \tanh^2(C_t)) \\ dX_t &= dA_t \cdot W^T \\ dW &= \sum_t dA_t \cdot X_t^T \\ dU &= \sum_t dA_t \cdot H_{t-1}^T \\ dA_t &= dH_t \odot o_t \odot (1 - \tanh^2(C_t)) \odot g'_t\end{split}\]

其中:

  • (dH_t) 表示隐藏状态的梯度。

  • (dC_t) 表示细胞状态的梯度。

  • (dX_t) 表示输入梯度。

  • (dW, dU) 分别表示输入权重和隐藏状态权重的梯度。

  • (dA_t) 表示门控单元梯度。

  • (f_t, o_t, C_t, g_t) 分别为遗忘门、输出门、细胞状态、输入门的前向值。

  • (odot) 表示元素逐乘。

输入:
  • params - 静态参数数组,包含 LSTM 网络配置、权重、状态指针等。

  • dynamic_params - 动态参数数组,用于存储运行时指针及中间梯度。

输出:
  • dX_ - 输入梯度。

  • dH_ - 隐藏状态梯度。

  • dC_ - 细胞状态梯度。

  • dA_tmp_ - 门控梯度中间缓存。

支持平台:

FT78NE MT7004

备注

  • FT78NE 支持 fp

  • MT7004 支持 hp, fp

共享存储版本:

void fp_lstmgradweight_s(long long *params, long long *dynamic_params, int core_mask)
void hp_lstmgradweight_s(long long *params, long long *dynamic_params, int core_mask)

C调用示例:

 1#include <stdio.h>
 2#include <lstmgradweight.h>
 3
 4int main() {
 5    long long params[32];
 6    long long dynamic_params[32];
 7    int core_mask = 0xff;
 8
 9    // 初始化 params 和 dynamic_params
10    fp_lstmgradweight_s(params, dynamic_params, core_mask);
11    return 0;
12}

私有存储版本:

void fp_lstmgradweight_p(long long *params, long long *dynamic_params)
void hp_lstmgradweight_p(long long *params, long long *dynamic_params)

C调用示例:

 1#include <stdio.h>
 2#include <lstmgradweight.h>
 3
 4int main() {
 5    long long params[32];
 6    long long dynamic_params[32];
 7
 8    // 初始化 params 和 dynamic_params
 9    fp_lstmgradweight_p(params, dynamic_params);
10    return 0;
11}